import torch 
from torch import nn 
from torch.nn import functional as F 
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
    BartModel,
    
)
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput



class BartGen(PreTrainedModel):
    """BART-large wrapped with an LM head tied to the shared token embeddings.

    ``forward`` projects decoder hidden states onto ``self.transformer.shared.weight``
    and adds ``final_logits_bias``. It runs in one of two modes selected by ``task``:
    generation (``task=-1``) or teacher-forced training with a cross-entropy loss
    (``task=0``).
    """

    def __init__(self, config, tokenizer):
        super(BartGen, self).__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.transformer = BartModel.from_pretrained('facebook/bart-large')
        # One additive bias per vocabulary entry, added to the LM logits in forward().
        self.register_buffer("final_logits_bias", torch.zeros((1, self.transformer.shared.num_embeddings)))

    def resize_token_embeddings(self, new_num_tokens=None):
        """Resize the shared token embeddings and ``final_logits_bias``.

        Args:
            new_num_tokens: target vocabulary size. Defaults to ``len(self.tokenizer)``
                so that tokens added to the tokenizer get embedding rows.

        Returns:
            The resized shared embedding module returned by the wrapped model.
        """
        if new_num_tokens is None:
            new_num_tokens = len(self.tokenizer)
        old_num_tokens = self.transformer.shared.num_embeddings
        new_embeddings = self.transformer.resize_token_embeddings(new_num_tokens)
        self.transformer.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        self.vocab_size = new_num_tokens

        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        """Truncate or zero-pad ``final_logits_bias`` to ``new_num_tokens`` columns."""
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # Keep the existing buffer's device and dtype for the padded columns.
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
                dtype=self.final_logits_bias.dtype,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm): # if use apex, this should be FusedLayerNorm
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_encoder(self):
        """Return the encoder of the wrapped BART model."""
        return self.transformer.encoder

    def get_output_embeddings(self):
        """Return a bias-free ``nn.Linear`` sharing the LM head weight (needed for generation).

        The layer's weight data is the shared embedding matrix itself, so it does
        not include ``final_logits_bias``.
        """
        vocab_size, emb_size = self.transformer.shared.weight.shape
        # nn.Linear(in_features, out_features) stores its weight as (out, in), so a
        # (vocab, emb) embedding matrix corresponds to a Linear mapping emb -> vocab.
        lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
        lin_layer.weight.data = self.transformer.shared.weight.data
        return lin_layer

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
    ):
        """Build the keyword arguments passed to ``forward`` at each decoding step."""
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        """Optionally force BOS at step 1 and EOS at the last step; returns ``logits``."""
        if cur_len == 1 and getattr(self.config, "force_bos_token_to_be_generated", False):
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """Force ``token_id`` to be generated by setting every other score to -inf, in place."""
        # Mask the actual logit width rather than config.vocab_size: after
        # resize_token_embeddings() the logits are len(tokenizer) wide while
        # self.config.vocab_size may still hold the pre-resize size.
        others = torch.ones(scores.size(-1), dtype=torch.bool, device=scores.device)
        others[token_id] = False
        scores[:, others] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder each layer's cached states along the batch dimension to follow ``beam_idx``.

        Accepts layers stored as ``{attn_key: {buffer_name: tensor or None}}`` dicts
        (the layout this method was written for) or as sequences of tensors
        (assumed to be the tuple-style cache of newer transformers versions -- confirm).
        """

        def _select(state):
            # Some cached buffers (e.g. a padding mask) may be None; leave those as-is.
            return None if state is None else state.index_select(0, beam_idx)

        reordered_past = []
        for layer_past in past:
            if isinstance(layer_past, dict):
                # Get the correct batch idx from the decoder layer's batch dim for cross and self-attn.
                layer_past_new = {
                    attn_key: {name: _select(buf) for name, buf in attn_cache.items()}
                    for attn_key, attn_cache in layer_past.items()
                }
            else:
                layer_past_new = tuple(_select(state) for state in layer_past)
            reordered_past.append(layer_past_new)
        return reordered_past

    def forward(self, input_ids,
        attention_mask=None,
        encoder_outputs=None,
        use_cache=False,
        past_key_values=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task=-1):
        """Run the model in generation (``task=-1``) or training (``task=0``) mode.

        Returns:
            task=-1: a ``Seq2SeqLMOutput`` (loss=None) when ``return_dict`` is truthy,
                otherwise the tuple ``(lm_logits, *outputs[1:])``.
            task=0: the tuple ``(loss, lm_logits, *outputs[1:])`` where ``loss`` is the
                token-level cross-entropy with padding positions ignored.

        Raises:
            ValueError: if ``task`` is not -1 or 0, or if ``task == 0`` and
                ``decoder_input_ids`` is None.
        """
        # generation
        if task == -1:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=use_cache,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            lm_logits = F.linear(outputs[0], self.transformer.shared.weight, bias=self.final_logits_bias)
            masked_lm_loss = None

            if not return_dict:
                output = (lm_logits,) + outputs[1:]
                return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

            return Seq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        # training
        if task == 0:
            if decoder_input_ids is None:
                raise ValueError("decoder_input_ids is required when task == 0 (training)")
            # Teacher forcing: feed positions [0, T-1) and predict positions [1, T), i.e. the
            # labels are decoder_input_ids shifted one position to the LEFT.
            y_ids = decoder_input_ids[:, :-1]
            labels = decoder_input_ids[:, 1:].clone()
            # -100 is nn.CrossEntropyLoss's default ignore_index, so padding adds no loss.
            labels[labels == self.tokenizer.pad_token_id] = -100
            if decoder_attention_mask is not None:
                # Align the mask with the truncated decoder inputs.
                decoder_attention_mask = decoder_attention_mask[:, :-1]

            outputs = self.transformer(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=y_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=False,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            sequence_output = outputs[0]
            lm_logits = F.linear(sequence_output, self.transformer.shared.weight, bias=self.final_logits_bias)
            outputs = (lm_logits,) + outputs[1:]  # Add cache, hidden states and attention if they are here

            loss_fct = nn.CrossEntropyLoss()
            # Use the logits' own width so this works whether or not resize_token_embeddings()
            # (which sets self.vocab_size) has been called.
            masked_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            return (masked_lm_loss,) + outputs

        raise ValueError(f"Unsupported task {task!r}; expected -1 (generation) or 0 (training)")

    
    
        